# -*- coding: utf-8 -*-
"""
Created on Mon Apr 10 23:20:56 2023

@author: 29672366
"""

import os
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from utils.JiebaTokenizer import JiebaTokenizer

tokenizer = JiebaTokenizer()

from models.TransformerQA import TransformerQA
from utils.QADataset import QADataset

# Training-data directories. Each path holds QA pairs; the boolean flag is
# forwarded to QADataset per source (presumably "is interview data" — the
# exact semantics live in QADataset; TODO confirm).
traindata_byglm = "./traindatas/QA/train/fromglm"
traindata_bygpt = "./traindatas/QA/train/fromgpt"
traindata_interview = "./traindatas/QA/train/interview"

data_paths = [(traindata_byglm,False),
              (traindata_bygpt,False),
              (traindata_interview,True)]

# Prefer GPU when available; all tensors are moved to this device in collate_fn.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = 'cpu'
print("Using device:", device)

# Model hyperparameters
vocab_size = tokenizer.vocab_size   # vocabulary size taken from the Jieba tokenizer
embedding_dim = 96                  # token embedding width
hidden_dim = 1024                   # feed-forward / hidden size
num_layers = 3                      # number of transformer layers
num_heads = 12                      # attention heads per layer
max_seq_len = 512                   # maximum input/target sequence length
# Checkpoints are grouped by hyperparameter combination.
model_save_path = f'./run/{embedding_dim}_{hidden_dim}_{num_layers}_{num_heads}'


def collate_fn(batch):
    """Collate a batch of (input, target) token-id tensor pairs.

    Both sides are padded to the longest sequence in the batch with the
    tokenizer's PAD id and moved to the module-level training device.

    Args:
        batch: list of (input_ids, target_ids) 1-D LongTensor pairs.

    Returns:
        (inputs, targets): two (batch, max_len) tensors on `device`.
    """
    # Unzip the list of pairs in one pass instead of a manual append loop.
    inputs, targets = zip(*batch)
    inputs = pad_sequence(list(inputs), batch_first=True,
                          padding_value=tokenizer.pad_token_id).to(device)
    targets = pad_sequence(list(targets), batch_first=True,
                           padding_value=tokenizer.pad_token_id).to(device)
    # NOTE(review): the original called torch.cuda.empty_cache() here on
    # every batch. That forces a GPU synchronization each step and does not
    # lower peak memory usage (the caching allocator reuses freed blocks),
    # so it was removed. Re-add only if fragmentation is actually observed.
    return inputs, targets

def _save_checkpoint(path, epoch, model, optimizer, loss):
    """Serialize model + optimizer state and bookkeeping info to `path`."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)


def train(model, dataset, batch_size, epochs, optimizer, criterion, device, model_save_path, resume=False):
    """Run the training loop, checkpointing the best and latest model.

    Args:
        model: seq2seq model called as model(inputs, targets, teachforcing=..., temperature=...).
        dataset: torch Dataset yielding (input_ids, target_ids) pairs.
        batch_size: DataLoader batch size.
        epochs: number of passes over the dataset.
        optimizer: torch optimizer over model.parameters().
        criterion: loss over (N, vocab) logits vs (N,) target ids.
        device: torch device (or device string) to train on.
        model_save_path: directory for checkpoints (created if missing).
        resume: if True, restore model/optimizer/best_loss from
            `<model_save_path>/best_model.pth` when it exists.

    Side effects:
        Writes `best_model.pth` whenever the epoch-average loss improves,
        and overwrites a rolling `saved_model_tech*.pth` every epoch.
    """
    import gc

    os.makedirs(model_save_path, exist_ok=True)
    best_loss = float('inf')
    if resume:
        resume_path = os.path.join(model_save_path, 'best_model.pth')
        print(resume_path)
        if os.path.exists(resume_path):
            # NOTE(review): torch.load uses pickle under the hood — only
            # load checkpoints from trusted sources (consider weights_only=True
            # on newer torch versions).
            checkpoint = torch.load(resume_path, map_location=torch.device(device))
            best_loss = checkpoint['loss']
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(f"Loaded checkpoint from {resume_path}, best_loss: {best_loss}")
            # Drop the checkpoint dict promptly to release host memory.
            del checkpoint
            gc.collect()
        else:
            print(f"No checkpoint found at {resume_path}, starting from scratch")

    model.to(device)
    # Optimizer state restored from a checkpoint may live on a different
    # device than the model; move every tensor in it alongside the model.
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)

    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    total_steps = len(train_loader)
    full_save = False   # if True, also keep a loss-named copy of each new best
    tech = 0.75         # teacher-forcing ratio passed to the model
    temperature = 0.75  # sampling temperature for non-forced decoding steps
    for epoch in range(epochs):
        epoch_loss = 0.0
        for step, (input_seq, target_seq) in enumerate(train_loader):
            optimizer.zero_grad()
            # The model applies its teacher-forcing schedule internally.
            logits = model(input_seq, target_seq, teachforcing=tech, temperature=temperature)
            # Flatten to (batch*seq, vocab) vs (batch*seq,) for the criterion.
            loss = criterion(logits.view(-1, logits.size(-1)), target_seq.view(-1))
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()  # read the scalar once per step
            epoch_loss += batch_loss
            print(f"Epoch [{epoch+1}/{epochs}], Step [{step+1}/{total_steps}], Loss: {batch_loss:.4f}")

        epoch_loss /= total_steps

        model_name = f'saved_model_tech{tech:.2f}.pth'
        model_path = os.path.join(model_save_path, model_name)
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_model_path = os.path.join(model_save_path, 'best_model.pth')
            _save_checkpoint(best_model_path, epoch, model, optimizer, best_loss)
            if full_save:
                # Additionally keep an immutable, loss-tagged copy of this best.
                best_model_path = os.path.join(model_save_path, f'best_model_tech{tech:.2f}_loss{best_loss:.4f}.pth')
                _save_checkpoint(best_model_path, epoch, model, optimizer, best_loss)
            print(f"Epoch [{epoch+1}/{epochs}] finished, saved best model to {best_model_path} avg loss{epoch_loss:.4f}")
        else:
            print(f"Epoch [{epoch+1}/{epochs}] finished, model not saved as loss did not improve")

        # Always overwrite a rolling latest-epoch checkpoint.
        _save_checkpoint(model_path, epoch, model, optimizer, epoch_loss)

        print(f"Saved epoch {epoch+1} model to {model_path}")


if __name__ == '__main__':
   
    # Build the model from the module-level hyperparameters.
    model = TransformerQA(vocab_size, embedding_dim, hidden_dim, num_layers, num_heads,max_seq_len)
    # Loss and optimizer; ignore_index=0 skips PAD positions in the loss
    # (assumes the tokenizer's pad_token_id is 0 — TODO confirm).
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.999),weight_decay = 0.01)
    # Load the combined QA dataset from all configured sources.
    train_dataset = QADataset(data_paths = data_paths, max_seq_length = max_seq_len,tokenizer = tokenizer)
    # Start training; resume=True restores best_model.pth if present.
    train(model, train_dataset, batch_size=1, epochs=1000, optimizer=optimizer, criterion=criterion, device=device, model_save_path=model_save_path, resume=True)